{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第 12 章 自然语言生成实战"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 12.1 LSTM 写诗"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 12.1.3 实现 LSTM 写诗"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['公子申敬爱，携朋玩物华。人是平阳客，地即石崇家。水文生旧浦，风色满新花。日暮连归骑，长川照晚霞。', '高门引冠盖，下客抱支离。绮席珍羞满，文场翰藻摛。蓂华雕上月，柳色蔼春池。日斜归戚里，连骑勒金羁。', '今夜可怜春，河桥多丽人。宝马金为络，香车玉作轮。连手窥潘掾，分头看洛神。重城自不掩，出向小平津。']\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import glob\n",
    "import json\n",
    "import random\n",
    "import operator\n",
    "import collections\n",
    "\n",
    "from typing import List, Dict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "# 只保留五言绝句\n",
    "def should_keep(paragraphs: List[str]):\n",
    "    return all([len(par) == 12 for par in paragraphs])\n",
    "\n",
    "# 读取数据\n",
    "def read_all_data(path: str):\n",
    "    poems = []\n",
    "    files = glob.glob(os.path.join(path, 'poet.tang.*.json'))\n",
    "    for file in files:\n",
    "        file_data = json.load(open(file, 'r'))\n",
    "        for item in file_data:\n",
    "            if should_keep(item['paragraphs']):\n",
    "                poem = ''.join(item['paragraphs'])\n",
    "                poems.append(poem)\n",
    "    return poems\n",
    "\n",
    "poems = read_all_data('data/poetry')\n",
    "# 为了加快训练，这里只取了 10000 条诗，可以根据自己资源增加或者减少\n",
    "poems = poems[:10000]\n",
    "print(poems[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class Processor(object):\n",
    "\n",
    "    def build_token_dict(self, corpus: List[List[str]]):\n",
    "        \"\"\"\n",
    "        构建 token 字典，这个方法将会遍历分词后的语料，构建一个标记频率字典和标记与索引的映射字典\n",
    "\n",
    "        Args:\n",
    "            corpus: 所有分词后的语料\n",
    "        \"\"\"\n",
    "        token2idx = {\n",
    "            '<PAD>': 0,\n",
    "            '<UNK>': 1,\n",
    "            '<BOS>': 2,\n",
    "            '<EOS>': 3\n",
    "        }\n",
    "\n",
    "        token2count = {}\n",
    "        for sentence in corpus:\n",
    "            for token in sentence:\n",
    "                count = token2count.get(token, 0)\n",
    "                token2count[token] = count + 1\n",
    "        # 按照词频降序排序\n",
    "        sorted_token2count = sorted(token2count.items(),\n",
    "                                    key=operator.itemgetter(1),\n",
    "                                    reverse=True)\n",
    "        token2count = collections.OrderedDict(sorted_token2count)\n",
    "\n",
    "        for token in token2count.keys():\n",
    "            if token not in token2idx:\n",
    "                token2idx[token] = len(token2idx)\n",
    "        return token2idx, token2count\n",
    "\n",
    "    @staticmethod\n",
    "    def numerize_sequences(sequence: List[str],\n",
    "                           token2index: Dict[str, int]) -> List[int]:\n",
    "        \"\"\"\n",
    "        将分词后的标记（token）数组转换成对应的索引数组\n",
    "        如 ['我', '想', '睡觉'] -> [10, 313, 233]\n",
    "\n",
    "        Args:\n",
    "            sequence: 分词后的标记数组\n",
    "            token2index: 索引词典\n",
    "        Returns: 输入数据对应的索引数组\n",
    "        \"\"\"\n",
    "        token_result = []\n",
    "        for token in sequence:\n",
    "            token_index = token2index.get(token)\n",
    "            if token_index is None:\n",
    "                token_index = token2index['<UNK>']\n",
    "            token_result.append(token_index)\n",
    "        return token_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p = Processor()\n",
    "# 这里我们对所有的诗做了基于字的分词，然后再构建词表\n",
    "p.token2idx, p.token2count = p.build_token_dict([list(seq) for seq in poems])\n",
    "# 由于我们这是文本生成，还需要一个索引到词的映射关系\n",
    "p.idx2token = dict([(v, k) for k,v in p.token2idx.items()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# 先定义一下两个全局变量，输入序列长度和批次大小\n",
    "INPUT_LEN = 6\n",
    "BATCH_SIZE = 500\n",
    "\n",
    "# 所有的诗整合为一个大字符串，方便后续遍历\n",
    "corpus = ''.join(poems)\n",
    "\n",
    "def data_generator():\n",
    "    t = 0\n",
    "    while True:\n",
    "        x_data = []\n",
    "        y_data = []\n",
    "        for i in range(BATCH_SIZE):\n",
    "            # 取出 t 到 t + INPUT_LEN 位置的字符串序列作为输入\n",
    "            x = corpus[t: t + INPUT_LEN]\n",
    "            # 取出 t + INPUT_LEN 位置的字符串作为输出\n",
    "            y = corpus[t + INPUT_LEN]\n",
    "\n",
    "            # 输入输出转换为数字\n",
    "            x_data.append(p.numerize_sequences(list(x), p.token2idx))\n",
    "            y_data.append(p.token2idx[y])\n",
    "\n",
    "            t += 1\n",
    "            # 当游标到了最后，从头开始遍历\n",
    "            if t + 1 >= len(corpus) - INPUT_LEN:\n",
    "                t = 0\n",
    "\n",
    "        x_data = np.array(x_data)\n",
    "        # 将输出序列转换为 one-hot 编码\n",
    "        y_data = to_categorical(y_data, len(p.token2idx))\n",
    "\n",
    "        yield x_data, y_data\n",
    "        \n",
    "# 初始化数据生成器\n",
    "# 如果想观察生成器每一步产生的数据，初始化生成器后调用 `next(gen)` 函数观察\n",
    "gen = data_generator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding (Embedding)        (None, 6, 50)             294300    \n",
      "_________________________________________________________________\n",
      "lstm (LSTM)                  (None, 128)               91648     \n",
      "_________________________________________________________________\n",
      "dropout (Dropout)            (None, 128)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 5886)              759294    \n",
      "=================================================================\n",
      "Total params: 1,145,242\n",
      "Trainable params: 1,145,242\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "WARNING:tensorflow:From <ipython-input-5-bf7c91949fe5>:18: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use Model.fit, which supports generators.\n",
      "Epoch 1/10\n",
      "1259/1259 [==============================] - 84s 67ms/step - loss: 6.2021\n",
      "Epoch 2/10\n",
      "1259/1259 [==============================] - 88s 70ms/step - loss: 5.8193\n",
      "Epoch 3/10\n",
      "1259/1259 [==============================] - 89s 71ms/step - loss: 5.6345\n",
      "Epoch 4/10\n",
      "1259/1259 [==============================] - 84s 67ms/step - loss: 5.4634\n",
      "Epoch 5/10\n",
      "1259/1259 [==============================] - 84s 67ms/step - loss: 5.3225\n",
      "Epoch 6/10\n",
      "1259/1259 [==============================] - 84s 67ms/step - loss: 5.2167\n",
      "Epoch 7/10\n",
      "1259/1259 [==============================] - 84s 67ms/step - loss: 5.1337\n",
      "Epoch 8/10\n",
      "1259/1259 [==============================] - 96s 76ms/step - loss: 5.0641\n",
      "Epoch 9/10\n",
      "1259/1259 [==============================] - 99s 79ms/step - loss: 5.0027\n",
      "Epoch 10/10\n",
      "1259/1259 [==============================] - 107s 85ms/step - loss: 4.9474\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x13de92240>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "L = tf.keras.layers\n",
    "\n",
    "model = tf.keras.Sequential([\n",
    "    L.Embedding(input_dim=len(p.token2idx), output_dim=50, input_shape=(6, )),\n",
    "    L.LSTM(128),\n",
    "    L.Dropout(0.1),\n",
    "    L.Dense(len(p.token2idx), activation='softmax')\n",
    "])\n",
    "\n",
    "model.compile(optimizer='adam', loss='categorical_crossentropy')\n",
    "model.summary()\n",
    "\n",
    "\n",
    "# 每个 epoch 步数等于（整个语料序列长度 - 窗口长度）除以批次大小\n",
    "steps = (len(corpus) - INPUT_LEN - 1) // BATCH_SIZE\n",
    "model.fit_generator(gen,\n",
    "                    steps_per_epoch=steps,\n",
    "                    epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def sample(preds: np.ndarray, temperature: float = 1.0) -> int:\n",
    "    \"\"\"\n",
    "    使用 softmax 温度随机采样\n",
    "    当 temperature = 1.0 时，模型输出正常\n",
    "    当 temperature = 0.5 时，模型输出比较open\n",
    "    当 temperature = 1.5 时，模型输出比较保守\n",
    "\n",
    "    Args:\n",
    "        preds: 模型预测结果\n",
    "        temperature: softmax 温度\n",
    "    Returns:\n",
    "        采样结果\n",
    "    \"\"\"\n",
    "    preds = np.asarray(preds).astype('float64')\n",
    "    exp_preds = np.power(preds, 1. / temperature)\n",
    "    preds = exp_preds / np.sum(exp_preds)\n",
    "    pro = np.random.choice(range(len(preds)), 1, p=preds)\n",
    "    return int(pro.squeeze())\n",
    "\n",
    "def predict_next_char(input_seq: List[str],\n",
    "                      temperature: float = 1.0) -> str:\n",
    "    \"\"\"\n",
    "    输入序列，预测下一个字符\n",
    "\n",
    "    Args:\n",
    "        input_seq: 输入序列\n",
    "        temperature: softmax 温度\n",
    "    Returns:\n",
    "        下一个字符串\n",
    "    \"\"\"\n",
    "    if len(input_seq) < INPUT_LEN:\n",
    "        raise ValueError(f'seq length must large than {INPUT_LEN}')\n",
    "\n",
    "    input_seq = input_seq[-INPUT_LEN:]\n",
    "    input_tensor = p.numerize_sequences(input_seq, p.token2idx)\n",
    "    input_tensor = np.array([input_tensor])\n",
    "    preds = model.predict(input_tensor)[0]\n",
    "    pred_idx = sample(preds, temperature)\n",
    "    pred_char = p.idx2token[pred_idx]\n",
    "    return pred_char\n",
    "\n",
    "def pred_with_start(input_seq: List[str],\n",
    "                    temperature: float = 1.0) -> List[str]:\n",
    "    \"\"\"\n",
    "    以给定字符串作为开头写诗\n",
    "\n",
    "    Args:\n",
    "        input_seq: 诗开头字符串\n",
    "        temperature: softmax 温度\n",
    "    Returns:\n",
    "        生成的诗歌序列\n",
    "    \"\"\"\n",
    "    result = input_seq\n",
    "    # 如果长度不足，则随机取一首诗补全\n",
    "    if len(input_seq) < INPUT_LEN:\n",
    "        padding_poem = list(random.choice(poems))\n",
    "    else:\n",
    "        padding_poem = []\n",
    "\n",
    "    # 当序列中出现四个 。 或者序列长度超过 100 时候停止\n",
    "    # 100 这个限制主要是为了避免出现死循环\n",
    "    should_continue = True\n",
    "    while should_continue:\n",
    "        pred_char = predict_next_char(padding_poem + result, temperature)\n",
    "        result.append(pred_char)\n",
    "        if result.count('。') == 4 or len(result) > 100:\n",
    "            should_continue = False\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Temperature: 0.3\n",
      "冬日乘何处，春风落远时。春风吹落日，秋雨带寒风。惆怅南山上，归人独自怜。一年无事意，万里见长安。\n",
      "冬日长安见，朝风满客游。从君不可见，何必为相思。一生千里里，万里万里中。谁能当夜暮，犹复向云中。\n",
      "冬日日已长，夜云犹自适。何处归别人，时来不可见。君子不可怜，何人不可忘。何事无所见，此心无所思。\n",
      "\n",
      "Temperature: 0.6\n",
      "冬日生岁月，龙花过自流。月晴从月落，山色不归归。独与江浪白，秋来水鸟红。孤舟楚高塞，积露夜烟山。\n",
      "冬日相送处，日暮东山暮。孤竹静未疏，寒钟雨相引。江日未繁别，云山无任情。尧年西五重，簪悴万里难。\n",
      "冬日五十年，何事再三扬。身上若无道，我来无所心。从君最不惜，作者我何何。自闻花落水，不及凤凰州。\n",
      "\n",
      "Temperature: 1.0\n",
      "冬日临蝉骑，陂川照雪农。气酣明刃管，空质暗珊红。咫里栽名废，旧人饮手募。素功虽百沂，玉富慎犹剖。\n",
      "冬日忽应领，白云偏可唾。旧家俱用台，百事本元珪。今尚有回好，随天谁处家。静行复有意，大事所相惊。\n",
      "冬日相无寄，音生始有修。或须人苦卜，息别一泉鸣。时独俱成出，开歌共自归。中云不可见，见赏仍流索。\n",
      "\n",
      "Temperature: 1.2\n",
      "冬日云飞雨，荆川白日碑。云华陈月夜，花上犬毖垂。颦若临沙日，放陵满野舟。逢游望何在，非止移更夸。\n",
      "冬日怜应挂，风阴动无康。早蝉繁石鹊，冬稻接篮笺。隐盎东溪残，沟邻麋夜光。上朝孤掖客，然夏到江边。\n",
      "冬日剪袍劲，耕毫扫建舆。神黑商川阔，宋泉向老宽。自嫌飞策宴，古合谢繇翔。抱路经朝罢，琴心出客情。\n",
      "\n",
      "Temperature: 1.5\n",
      "冬日古来戍，离舟斜寂堂。恐重开骢寿，形滋游哲异。春陌外万物，漾旗仰荥城。憯裘开橤火，于渥凋伥药。\n",
      "冬日竟相永，滴欢回雒阴。暮寒灯景远，春雨浪鼠豪。暝雪俱即扇，㶉严称凤筝。槭峨楼萼迥，翛恍无人宿。\n",
      "冬日失知唧，帽与愿要迫。刘怅门学愚，度十剧陀彦。谁苦拍鳃木，野老裁酸兔。闻角涉大处，讴家出山山。\n"
     ]
    }
   ],
   "source": [
    "for temp in [0.3, 0.6, 1.0, 1.2, 1.5]:\n",
    "    print(f'\\nTemperature: {temp}')\n",
    "    for _ in range(3):\n",
    "        print(''.join(pred_with_start(['冬', '日'], temp)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def predict_hide(head_tokens: List[str],\n",
    "                 temperature: float = 1.0) -> List[str]:\n",
    "    \"\"\"\n",
    "    写藏头诗\n",
    "\n",
    "    Args:\n",
    "        head_tokens: 每一句的第一个字组成的数组\n",
    "        temperature: softmax 温度\n",
    "    Returns:\n",
    "        生成的诗歌序列\n",
    "    \"\"\"\n",
    "    padding_poem = list(random.choice(poems))\n",
    "    result = []\n",
    "\n",
    "    for i in range(4):\n",
    "        result.append(head_tokens[i])\n",
    "        sentence_end = False\n",
    "        while not sentence_end:\n",
    "            char = predict_next_char(padding_poem + result, temperature)\n",
    "            result.append(char)\n",
    "            if char == '。' or len(result) > 100:\n",
    "                sentence_end = True\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Temperature: 0.3\n",
      "机门高草里，兔脉落阳台。器为高名在，思浮此始悲。学之难恋恶，自醉又成衰。习来舟陵上，彻历烟天外。\n",
      "机占已分酒，深身更似情。器深芳海峻，携手在天前。学迹随节场，归丝对高蠹。习语去乡场，殊人碌轶记。\n",
      "机岂已相久，回南宁相知。器石于朱墀，寂寞五阳汉。学置六里运，争见一庭幽。习客阳菲路，贻程任此情。\n",
      "\n",
      "Temperature: 0.6\n",
      "机眄造道师，手充皆送潘。器财仙出菊，𫌇态桂微铺。学足本将戴，滥哉全者孟。习仁天圣远，及小鼎彝诗。\n",
      "机劝招台步，文歌发圣荣。器书军偃激，竞败瑶兴白。学季西北建，方令波碧残。习花休镐印，方贺诏清褒。\n",
      "机绿闾中识，香晴剑上然。器承佐忠大，贻说在冥茫。学就忆时处，逢城无似阑。习手即常制，羞令闲镆崭。\n",
      "\n",
      "Temperature: 1.0\n",
      "机兴服灵策，郑义降衰下。器复转联绁，空沾荷酒花。学婴劳暂看，烟壑空远传。习息同止麦，生涯著占之。\n",
      "机灵重幕浦，愿以胡江柏。器勃宫群籍，龙佳玉压膺。学书徒郁就，锵义可思过。习我刘凤末，遂欲歌山沃。\n",
      "机调当态远，驰州忘辛歌。器负金粱茎，□辉荆葛峩。学宽刊帝府，更此入皇天。习忠行远札，奄解咏新香。\n",
      "\n",
      "Temperature: 1.2\n",
      "机相愁返望，复往别来家。器舅随家府，其官忝一钧。学道耸难盛，弃患动方珍。习宠事王贵，负师奏武兵。\n",
      "机官起正去，鞍马心乘绝。器金闲布藻，尺马劳焙蠹。学名大肠为，怅何始与昔。习之徒逸伦，酺贱尚可泯。\n",
      "机心讵为色，书引妾何胸。器禄既一年，反将陶组宦。学旱五两铺，续农八二精。习成歌健愿，寻会曲生情。\n",
      "\n",
      "Temperature: 1.5\n",
      "机定共生明，群情宵且闲。器在疑爱羁，可见遶头木。学义且可见，深为则子病。习剑剸新琴，□诗示长肉。\n",
      "机鸣每寒黛，荆官自向春。器言新大梦，妆乐是登峰。学拜入临草，来行或逐鸥。习时能妄复，犹恐尔无伤。\n",
      "机从死重寿，卒实祸诸途。器提妄士论，薏别旋陵棱。学位物高省，平生岂一顾。习汉君未回，贼丈我徒同。\n"
     ]
    }
   ],
   "source": [
    "for temp in [0.3, 0.6, 1.0, 1.2, 1.5]:\n",
    "    print(f'\\nTemperature: {temp}')\n",
    "    for _ in range(3):\n",
    "        print(''.join(predict_hide(['机', '器', '学', '习'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
